{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Reinforcement Learning _in Action_\n",
    "### Chapter 6\n",
    "#### Evolutionary Strategies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import gym"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Creating an Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/brandonbrown/anaconda3/envs/deeprl/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated.  Call .resolve and .require separately.\n",
      "  result = entry_point.load(False)\n"
     ]
    }
   ],
   "source": [
    "env = gym.make('CartPole-v1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_random_agent_weights(state_space=4, action_space=2):\n",
    "    return [\n",
    "        torch.rand(state_space, 10), # fc1 weights\n",
    "        torch.rand(10),  # fc1 bias\n",
    "        torch.rand(10, 10),  # fc2 weights\n",
    "        torch.rand(10),  # fc2 bias\n",
    "        torch.rand(10, action_space),  # fc3 weights\n",
    "        torch.rand(action_space),  # fc3 bias\n",
    "    ]\n",
    "\n",
    "def get_action_from_agent_weights(agent_weight, state):\n",
    "    x = F.relu(torch.add(torch.mm(torch.Tensor(state.reshape(1,-1)), agent_weight[0]), agent_weight[1]))\n",
    "    x = F.relu(torch.add(torch.mm(x, agent_weight[2]), agent_weight[3]))\n",
    "    act_prob = F.softmax(torch.add(torch.mm(x, agent_weight[4]), agent_weight[5])).detach().numpy()[0]\n",
    "    action = np.random.choice(range(len(act_prob)), p=act_prob)\n",
    "    return action"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Determining the Agent Fitness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_fitness(env, agent_weights, max_episode_length=500, trials=5):\n",
    "    total_reward = 0\n",
    "    for _ in range(trials):\n",
    "        observation = env.reset()\n",
    "        for i in range(max_episode_length):\n",
    "            action = get_action_from_agent_weights(agent_weights, observation)\n",
    "            observation, reward, done, info = env.step(action)\n",
    "            total_reward += reward\n",
    "            if done: break\n",
    "    return total_reward / trials\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Cross"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cross(agent1_weights, agent2_weights):\n",
    "    num_params = len(agent1_weights)\n",
    "    crossover_idx = np.random.randint(0, num_params)\n",
    "    new_weights = agent1_weights[:crossover_idx] + agent2_weights[crossover_idx:]\n",
    "    new_weights = mutate(new_weights)\n",
    "    return new_weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Mutate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mutate(new_weights):\n",
    "    num_params = len(new_weights)\n",
    "    num_params_to_update = np.random.randint(0, num_params)  # num of params to change\n",
    "    for i in range(num_params_to_update):\n",
    "        n = np.random.randint(0, num_params)\n",
    "        new_weights[n] = new_weights[n] + torch.rand(new_weights[n].size())\n",
    "    return new_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def breed(agent1_weights, agent2_weight, generation_size=10):\n",
    "    next_generation = [agent1_weights, agent2_weight]\n",
    "    for _ in range(generation_size - 2):\n",
    "        next_generation.append(cross(agent1_weights, agent2_weight))\n",
    "    return next_generation\n",
    "\n",
    "def reproduce(env, agents_weights, generation_size):\n",
    "    top_agents_weights = sorted(agents_weights, reverse=True, key=lambda a: get_fitness(env, a))[:2]\n",
    "    new_agents_weights = breed(top_agents_weights[0], top_agents_weights[1], generation_size)\n",
    "    return new_agents_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_generations = 100\n",
    "generation_size = 20\n",
    "generation_fitness = []\n",
    "max_fitness = 0\n",
    "\n",
    "agents = [init_random_agent_weights(), init_random_agent_weights()]\n",
    "\n",
    "for i in range(n_generations):\n",
    "    next_generation = reproduce(env, agents, generation_size)\n",
    "    ranked_generation = sorted([get_fitness(env, a) for a in next_generation], reverse=True)\n",
    "    avg_fitness = (ranked_generation[0] + ranked_generation[1]) / 2\n",
    "    print(i, avg_fitness)\n",
    "    generation_fitness.append(avg_fitness)\n",
    "    agents = next_generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x11d5b2da0>]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmcAAAF3CAYAAADgjOwXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XmYXVWd7//3N1WpzPNEJkhIgpEwE5CpWwFlUBDbAUVtULBp+2KrP+916u5f4+22fz/16dbWbq824AB0I4pKi6gMIiqzJMwJQyohoSpTDakMlUrN6/5RO1hAqnIqdaaqvF/PU8/Ze599zvnWfuDUJ2vttVaklJAkSVJ5GFHqAiRJkvRHhjNJkqQyYjiTJEkqI4YzSZKkMmI4kyRJKiOGM0mSpDJiOJMkSSojhjNJkqQyYjiTJEkqI4YzSZKkMlJZ6gIGY/r06WnBggWlLkOSJGm/Vq5c2ZBSmrG/84Z0OFuwYAErVqwodRmSJEn7FREbcjnPbk1JkqQyYjiTJEkqI4YzSZKkMmI4kyRJKiOGM0mSpDJiOJMkSSojhjNJkqQyYjiTJEkqI4YzSZKkMlLQcBYR6yPi6Yh4IiJWZMemRsTdEbEme5ySHY+I+EZEVEfEUxFxQiFrkyRJKkfFaDk7M6V0XEppebb/OeCelNIS4J5sH+B8YEn2cyXwrSLUJkmSVFZKsbbmRcCbsu3rgd8Cn82O35BSSsDDETE5ImanlDaXoEZJBbBywzZ2tnaWugxJeo1F08dz6LSxpS4DKHw4S8BdEZGA/0gpXQPM6hW4tgCzsu25QE2v19Zmx14RziLiSnpa1jj00EMLWLqkfHqqdjvv+tZDpS5Dkvbpc+cv5aNvXFTqMoDCh7MzUkobI2ImcHdEPNf7yZRSyoJbzrKAdw3A8uXLB/RaSaXz+xfqAbjpI29gTFVFiauRpFeaM3lMqUt4WUHDWUppY/ZYFxG3AicDW/d2V0bEbKAuO30jML/Xy+dlxyQNA/dXN7BszkROWzy91KVIUlkr2ICAiBgXERP2bgPnAM8AtwGXZaddBvws274NuDQbtXkKsMP7zaThoaW9k8c2bOcMg5kk7VchW85mAbdGxN7PuSmldEdEPAr8KCKuADYAF2fn/xJ4K1ANtAAfLmBtkoro0fVNtHd1c7rhTJL2q2DhLKW0Djh2H8cbgbP3cTwBVxWqHkml80B1A1UVIzhpwdRSlyJJZc8VAiQV3P1rGjjxsCkOBJCkHBjOJBVUY3Mbqzfv5IwldmlKUi4MZ5IK6sG1jQDebyZJOTKcSSqoB6obmDC6kqPnTip1KZI0JBjOJBVMSon71jRw2qJpVIyIUpcjSUOC4UxSwWxobGHj9j3ObyZJA2A4k1Qw91c3AN5vJkkDYTiTVDAPVDcwZ9JoFk4fV+pSJGnIMJxJKoiu7sSDaxs5ffF0spVCJEk5MJxJKohVm3awY0+H85tJ0gAZziQVxN77zU5bZDiTpIEwnEkqiAeqG1h6yARmTBhV6lIkaUgxnEnKu9aOLh5d3+QoTUk6AIYzSXm3Yn0T7Z3dzm8mSQfAcCYp7+6vbmBkRXDywqmlLkWShhzDmaS8e6C6geMPncK4UZWlLkWShhzDmaS8atrdzjObdtilKUkHyHAmKa8eWtdISi7ZJEkHynAmKa/ur25g/KhKjp03qdSlSNKQZDiTlFcPVDdwyuHTqKzw60WSDoTfnpLypmZbCxsaWzhj8bRSlyJJQ5bhTFLePJAt2eR6mpJ04AxnkvLm/uoGZk0cxaIZ40tdiiQNWYYzSXnR3Z14cG0jpy+eTkSUuhxJGrIMZ5LyYvXmnWzb3e78ZpI0SIYzSXmx934z5zeTpMExnEnKi/urG1gyczyzJo4udSmSNKQZziQNWmtHF4+u32armSTlgeFM0qA99lITrR3d3m8mSXlgOJM0aKs27gRg+YIpJa5EkoY+w5mkQatpamHi6Eomj60qdSmSNOQZziQNWm3THuZNGVvqMiRpWDCcSRq02qYW5k0ZU+oyJGlYMJxJGpSUEjXb9jB/qi1nkpQPhjNJg7Jtdzt7OrpsOZOkPDGcSRqUmqY9AMz3njNJygvDmaRBqW1qAWDeVFvOJCkfDGeSBqVmW0/LmaM1JSk/DGeSBqW2qYUpY0cyflRlqUuRpGHBcCZpUJzjTJLyy3AmaVBqmlqY7/1mkpQ3hjNJByylxEZbziQprwxnkg5Y/a422jq7me8cZ5KUN4YzSQds7xxntpxJUv4YziQdsJfnOLPlTJLyxnAm6YDV2nImSXlnOJN0wGqbWpg+vooxVRWlLkWShg3DmaQDVrPNkZqSlG+GM0kHrLapxfvNJCnPDGeSDkh3d2LjdlvOJCnfDGeSDsjWXa10dCVXB5CkPCt4OIuIioh4PCJuz/YXRsQjEVEdET+MiKrs+Khsvzp7fkGha5N04BypKUmFUYyWs08Az/ba/zLwtZTSYqAJuCI7fgXQlB3/WnaepDJVs61njjNXB5Ck/CpoOIuIecDbgOuy/QDOAn6cnXI98I5s+6Jsn+z5s7PzJZWhvS1ncyYbziQpnwrdcvavwGeA7mx/GrA9pdSZ7dcCc7PtuUANQPb8jux8SWWotqmFmRNGMXqkc5xJUj4VLJxFxAVAXUppZZ7f98qIWBERK+rr6/P51pIGoGbbHuZP9X4zScq3QracnQ68PSLWAzfT0535dWByRFRm58wDNmbbG4H5ANnzk4DGV79pSumalNLylNLyGTNmFLB8Sf2p3e4cZ5JUCAULZymlz6eU5qWUFgDvA36TUvoAcC/w7uy0y4CfZdu3Zftkz/8mpZQKVZ+kA9fZ1c2m7a3Md6SmJOVdKeY5+yzwqYiopueesu9kx78DTMuOfwr4XAlqk5SDLTtb6epOtpxJUgFU7v+UwUsp/Rb4bba9Djh5H+e0Au8pRj2SBqdmm3OcSVKhuEKApAGrbcrmOHN1AEnKO8OZpAGrbdpDBMyeZDiTpHwznEkasJqmFmZPHE1VpV8hkpRvfrNKGrDapj3ebyZJBWI4kzRgtduc40ySCsVwJmlA2ju72bKzlXmuDiBJBWE4kzQgW3a00p2w5UySCsRwJmlAavZOo+E9Z5JUEIYzSQOyd44zW84kqTAMZ5IGpGbbHipGBLMnjS51KZI0LBnOJA1IbVMLsyeNprLCrw9JKgS/XSUNSM8cZ3ZpSlKhGM4kDUhNU4uDASSpgAxnknLW1tnF1p1trg4gSQVkOJMOUjXbWvjRozW8sHVXzq/Z2LQHgPlT7daUpEKpLHUBkopjd1snD69r5Pcv1HPfmgbWNewG4IRDJ/PT/3F6Tu9Rm4UzW84kqXAMZ9Iw1d2dWL15J79fU8/vX6hn5YYmOroSo0eO4JTDp/HBUw6jpqmF7z2wntqmlpwC1x/DmS1nklQohjNpGGpsbuPd336IF7PWsdfPnsjlZyzkT5fMYPmCKYyqrAB6uja/98B6bn9qMx9946L9vm9NUwsjK4JZE53jTJIKxXAmDUP/fNfz1Gxr4cvvOpozl85k5oR9h6n5U8dy3PzJ/PzJTTmFs9qmPcyZPIaKEZHvkiVJGQcESMPMU7XbufnRGi47bQHvPenQPoPZXhceO4dVm3aytr55v+9ds81pNCSp0Axn0jDS3Z24+rZVTBs3ik+8eUlOr3nb0bOJgNuf3Lzfc52AVpIKz3AmlZGdrR08s3HHAb/+p49v5PGXtvPZ817HxNEjc3rNIZNGc/KCqdz25EZSSn2et6e9i4bmNsOZJBWY4UwqI1//9Rre/u/385vntg74tTtbO/jSr57j+EMn864T5g3otW8/bg5r63fz7Oa+5zzbuL0F6LlPTZJUOIYzqYzcv6aB7gQf/8ETPL8l98lhAb7x6zU07m7jf799GSMGeMP++UfNpmJE8POnNvV5To3TaEhSURjOpDJRv6uN57fu4rJTD2NsVQVXXP8ojc1tOb12zdZdfP/B9bx3+XyOmTd5wJ89dVwVZyyezs+f3NRn12bttqzlzAEBklRQhjOpTDy0rhGAd54wj2svXU79rjY++p8raevs6vd1KSW+8PNVjK2q4NPnvu6AP//CY+dQ27SHJ2q27/P52qY9VFWOYPr4UQf8GZKk/TOcSWXiobUNTBhdybI5Ezl2/mT+5eJjeXR9E3976zP93qh/xzNbeKC6kU+95QimDSI4nbNsFlWVI7jtyX13bdY0tTBv8pgBd5lKkgbGcCaViQeqG3nDwmlUVvT8b3nBMXP45JuX8OOVtVzz+3X7fM2e9i6++ItnWXrIBD54ymGD+vyJo0dy5utm8IunNtPV/dowWNu0h3kOBpCkgjOcSWWgZlsLL21r4fTF015x/BNnL+GCY2bzpTue4+7Vrx3B+a3frWXj9j184e3LXg51g3HhsXOo29XGH17c9prnnONMkopjv9/mEfGViJgYESMj4p6IqI+IDxajOOlg8dDanvvNTls0/RXHI4J/fs+xHDN3Ep+4+XFWb9r58nM121r49u/WcuGxczjl8FeGugN11tKZjK2qeM2ozd1tnWzb3e5gAEkqglz+qX1OSmkncAGwHlgMfLqQRUkHmwfXNjB9fBVHzBr/mudGj6zgmkuXM3H0SD5y/aPU7WoF4B9vX01FBH/z1qV5q2NsVSVvfv0sfvX0Zjq6ul8+Xus0GpJUNLmEs73TjL8NuCWldODTl0t6jZQSD65t5NRF04nY9832syaO5rrLlrOtpZ2/vHEld6/eyl2rt/KxsxYze1J+A9OFx86hqaWD+6sbXj5Wk02jYTiTpMLLJZz9PCKeA04E7omIGUBrYcuSDh5r65up29XG6Yv675o8au4kvnbxcTz+0nY++p8rWTBtLB/5k4V5r+dPj5jOxNGV/LzXqM3aJlcHkKRiySWcXQ2cBixPKXUALcDbC1qVdBB5sI/7zfbl/KNn8+lzX0d3Slx94TJGVVbkvZ5RlRWcd9Qh3LVqK60dPXOs1TbtYczICqaNq8r750mSXimXcPZQSmlbSqkLIKW0G/hVYcuSDh4PVDcwd/IY5k/NrcvwqjMX89jfvYUzl84sWE0XHjuH5rZOfvt8PZDNcTZlTJ/drpKk/Kns64mIOASYC4yJiOOBvd/KEwH7NqQ86OpOPLxuG+cumzWg4DOlwC1Ypx4+jWnjqvj5k5s476hDnEZDkoqoz3AGnAt8CJgHfLXX8V3A3xSwJumgsXrTTnbs6cipS7OYKitG8NajZ3PLyhqa2zqp2dbCiYdNKXVZknRQ6DOcpZSuB66PiHellH5SxJqkg8aDa3tGRJ62n8EApXDhsXO48eEN3PpYLTtbO205k6Qi6a/lbK/bI+L9wILe56eU/qFQRUkHiwfXNrJ45nhmThxd6lJeY/lhU5g9aTTX3NezdNQ8J6CVpKLIZUDAz4CLgE5gd68fSYPQ3tnNH17ctt8pNEplxIjggmNmU7OtZwJaVweQpOLIpeVsXkrpvIJXIh1knqzdzp6OLk4ts/vNervw2Dlce9+LgBPQSlKx5NJy9mBEHF3wSqSDzAPVDUTAKYdPLXUpfTp67iQOmzaW8aMqmTx25P5fIEkatFxazs4APhQRLwJt9EypkVJKxxS0MmmYe3BtI0fNmcTkseU7sWtE8Km3HEF1XbNznElSkeQSzs4veBXSQWZPexePv9TE5afnf/mlfLvouLmlLkGSDir77dZMKW0A5gNnZdstubxOUt8eXb+Njq7EaYvL934zSVJp7DdkRcTVwGeBz2eHRgL/WciipOHuwbWNjKwITlrgxK6SpFfKpQXsz+hZ6Hw3QEppEzChkEVJw92Daxs4fv4UxlblcmeBJOlgkks4a08pJSABRMS4wpYkDW87Wjp4ZuMOTi3T+c0kSaWVSzj7UUT8BzA5Iv4C+DVwbWHLkoavh19spDuV55JNkqTS22+fSkrpnyPiLcBO4HXA36eU7i54ZdIw9dDaRkaPHMHxh3q/mSTptfoNZxFRAfw6pXQmYCCT8uDBtQ2ctGAqVZUOepYkvVa/fx1SSl1Ad0RMKlI90rBWt6uVF7Y2c7pTaEiS+pDLULFm4OmIuJteC56nlD7e34siYjTwe2BU9jk/TildHRELgZuBacBK4M9TSu0RMQq4ATgRaATem1JaP/BfSSpfD61tBLzfTJLUt1zC2U+zn4Fqo2fi2uaIGAncHxG/Aj4FfC2ldHNEfBu4AvhW9tiUUlocEe8Dvgy89wA+VypbD1Y3MnF0Jcvm2BgtSdq3XAYEXH8gb5xNv9Gc7Y7MfhJwFvD+7Pj1wBfoCWcXZdsAPwb+PSIiex9pWHhwXQOnHD6NihGuUylJ2rdcVghYEhE/jojVEbFu708ubx4RFRHxBFBHz4CCtcD2lFJndkotsHfhvrlADUD2/A56uj6lYaFmWws12/bYpSlJ6lcuw8W+R0/LVidwJj33heW0fFNKqSuldBwwDzgZWHqAdb4sIq6MiBURsaK+vn6wbycVzYNrGwAcDCBJ6lcu4WxMSukeIFJKG1JKXwDeNpAPSSltB+4FTqVnMtu93anzgI3Z9kZ6Flgne34SPQMDXv1e16SUlqeUls+YMWMgZUgl09HVze1PbWbGhFEsnjm+1OVIkspYLuGsLSJGAGsi4mMR8WfAfv+6RMSMiJicbY8B3gI8S09Ie3d22mXAz7Lt27J9sud/4/1mGg6adrdz6Xf+wH1rGrjijIVEeL+ZJKlvuYzW/AQwFvg48I/0dG1e1u8reswGrs8msh0B/CildHtErAZujogvAo8D38nO/w5wY0RUA9uA9w3oN5HK0PNbdvGRGx5l6842vnrxsbzzhHmlLkmSVOb2t0LADHpGWFamlGqBD+f6ximlp4Dj93F8HT33n736eCvwnlzfXyp3d6/eyidvfpxxoyr54ZWnuFyTJCknfXZrRsRHgFXAvwHPRcTbi1aVNISllPjmvdVceeMKFs0cz20fO8NgJknKWX8tZ58ElqWU6iPicOC/6LkvTFIf9rR38ZmfPMXPn9zERcfN4cvvOobRIytKXZYkaQjpL5y1p5TqoacrMlteSVIfNu/Yw5U3rOSZTTv47HlL+egbD/fmf0nSgPUXzuZFxDf62t/f2prSwaKrO/HIi418/AdP0NrRxXWXLufs188qdVmSpCGqv3D26VftryxkIdJQ0NnVzZq6Zp7ZuKPnZ9NOVm/ayZ6OLg6dOpab/uINHDFrQqnLlCQNYX2GswNdU1MaTlJK/PLpLTy0roGnN+7kuc07aevsBmBcVQXL5kzifSfP56g5k3jzkbOYNGZkiSuWJA11ucxzJh207lvTwFU3PcaEUZUsmzuRS089jKPmTuKouZNYOG0cI1zAXJKUZ4YzqR+3PbmJCaMqefTv3uyoS0lSUfQ3z9mXs0cnhtVBqa2ziztXbeEty2YZzCRJRdPf2ppvjZ55AD5frGKkcnLfCw3sau3kwmPmlLoUSdJBpL9uzTuAJmB8ROwEgp6lnAJIKaWJRahPKpnbn9rEpDEjOX3x9FKXIkk6iPTZcpZS+nRKaTLwi5TSxJTShN6PRaxRKrrWji7uXr2V85YdQlVlfw3MkiTl134HBKSULoqIWcBJ2aFH9q4cIA1Xv32+jt3tXVxw7OxSlyJJOsjst0kgGxDwB+A9wMXAHyLi3YUuTCqlnz+1mWnjqjj18GmlLkWSdJDJZSqNvwNOSinVAUTEDODXwI8LWZhUKi3tnfzm2TreecJcKivs0pQkFVcuf3lG7A1mmcYcXycNSfc8W8eeji4ucJSmJKkEcmk5uyMi7gR+kO2/F/hl4UqSSuv2pzYxY8IoTl44tdSlSJIOQrkMCPh0RLwTOCM7dE1K6dbCliWVxq7WDu59vp73n3woFS7NJEkqgZyWb0op/RT4aYFrkUru189upb2zmwuOcZSmJKk0vHdM6uX2Jzcze9JoTjh0SqlLkSQdpAxnUmZHSwe/X1PP246ezQi7NCVJJTKgcBYRUyLimEIVI5XSnau30NGVuOBYR2lKkkonl0lofxsREyNiKvAYcG1EfLXwpUnFdftTm5k/dQzHzptU6lIkSQexXFrOJqWUdgLvBG5IKb0BeHNhy5KKa9vudh6obuBtR88hwi5NSVLp5BLOKiNiNj1LN91e4HqkkrjjmS10dSdHaUqSSi6XcPYPwJ1AdUrp0Yg4HFhT2LKk4rr9qU0snD6OZXMmlroUSdJBLpdJaG8Bbum1vw54VyGLkoqpflcbD69r5KozF9ulKUkquVwGBHwlGxAwMiLuiYj6iPhgMYqTiuFXz2ymO+FampKkspBLt+Y52YCAC4D1wGLg04UsSiqm25/czJKZ43ndIRNKXYokSbkNCMge3wbcklLaUcB6pKLasqOVRzdss9VMklQ2cllb8/aIeA7YA/xVRMwAWgtbllQcv3h6MynBBcc6SlOSVB7223KWUvoccBqwPKXUAbQAFxW6MKkYbn9qE6+fPZFFM8aXuhRJkoDcBgSMBf4H8K3s0BxgeSGLkoqhtqmFx1/a7txmkqSykss9Z98D2ulpPQPYCHyxYBVJRXLvc3UAvO1ow5kkqXzkEs4WpZS+AnQApJRaACeD0pD3/NZdTBhdyWHTxpa6FEmSXpZLOGuPiDFAAoiIRUBbQauSiqC6rpnFM8c78awkqazkEs6uBu4A5kfEfwH3AJ8paFVSEVTX7WaxAwEkSWUml+Wb7o6Ix4BT6OnO/ERKqaHglUkFtL2lnYbmNpbMMpxJkspLLvOcAYwGmrLzj4wIUkq/L1xZUmFV1zUDsHim4UySVF72G84i4svAe4FVQHd2OAGGMw1ZL4ezGS7ZJEkqL7m0nL0DeF1KyUEAGjaq65oZVTmCuVPGlLoUSZJeIZcBAeuAkYUuRCqmNXXNLJoxnooRjtSUJJWXXFrOWoAnIuIeek2hkVL6eMGqkgqsuq6ZEw+bUuoyJEl6jVzC2W3ZT2+pALVIRdHS3snG7Xt470nzS12KJEmvkUs4m5xS+nrvAxHxiQLVIxXcuvrdgCM1JUnlKZd7zi7bx7EP5bkOqWjW1O0CYInhTJJUhvpsOYuIS4D3Awsjone35gRgW6ELkwqluq6ZihHBYdPGlboUSZJeo79uzQeBzcB04F96Hd8FPFXIoqRCqq5r5rBpY6mqzKXhWJKk4uoznKWUNgAbgFOLV45UeNV1za6pKUkqW302HUTE/dnjrojY2etnV0TsLF6JUv60d3azvrHFNTUlSWWrv5azM7JH17fRsLGhcTdd3cmRmpKkstVfy9k7e20PeLbOiJgfEfdGxOqIWLV3+o2ImBoRd0fEmuxxSnY8IuIbEVEdEU9FxAkH8gtJ/XFNTUlSuevvjui/67V9zwG8dyfwP1NKRwKnAFdFxJHA54B7UkpLsvf9XHb++cCS7OdK4FsH8JlSv/aGs0UzHakpSSpP/YWz6GM7JymlzSmlx7LtXcCzwFzgIuD67LTr6VlYnez4DanHw8DkiJg90M+V+rOmrpm5k8cwtiqX+ZclSSq+/v5CjYmI4+kJcKOz7ZdD2t7glYuIWAAcDzwCzEopbc6e2gLMyrbnAjW9XlabHduMlCfVdc3ebyZJKmv9hbPNwFez7S29tqFnbc2zcvmAiBgP/AT4ZEppZ8QfG+FSSikiBrROZ0RcSU+3J4ceeuhAXqqDXHd3Yl1DM6cumlbqUiRJ6lN/ozXPHOybR8RIeoLZf6WUfpod3hoRs1NKm7Nuy7rs+Eag90rU87Jjr67rGuAagOXLl7sAu3K2cfseWju6bTmTJJW1gk2RHj1NZN8Bnk0p9W51u40/rtd5GfCzXscvzUZtngLs6NX9KQ2aa2pKkoaCQt4VfTrw58DTEfFEduxvgC8BP4qIK+hZgeDi7LlfAm8FqoEW4MMFrE0HoZen0TCcSZLKWMHCWUrpfvoe5Xn2Ps5PwFWFqkeqrmtm+vgqJo+tKnUpkiT1ab/hrI/JYHcAG1JKnfkvSSqM6rpmFrmmpiSpzOXScvZ/gBOAp+hpCTsKWAVMioi/SindVcD6pLxIKbGmrpmLjptT6lIkSepXLgMCNgHHp5SWp5ROpGe+snXAW4CvFLI4KV/qd7Wxq7WTxbacSZLKXC7h7IiU0qq9Oyml1cDSlNK6wpUl5dcfBwO4pqYkqbzl0q25KiK+Bdyc7b8XWB0Ro4COglUm5VF1vSM1JUlDQy4tZx+iZ3qLT2Y/67JjHcCgJ6qVimHN1mYmjKpk1sRRpS5FkqR+5dJydj7w7ymlf9nHc815rkcqiOq6ZhbNHE/v5cMkSSpHubScXQi8EBE3RsQFEVHIiWulgqiud8FzSdLQsN9wllL6MLAYuAW4BFgbEdcVujApX3a0dFC/q81lmyRJQ0JOrWAppY6I+BWQgDHAO4CPFLIwKV+q63vW1LTlTJI0FOy35Swizo+I7wNrgHcB1wGHFLguKW9cU1OSNJTk0nJ2KfBD4C9TSm0FrkfKu+q6ZqoqRzBvythSlyJJ0n7tN5yllC7pvR8RZwCXpJRcpFxDwt41NStGOFJTklT+crrnLCKOB94PvAd4EfhpIYuS8mlNXTPHHzql1GVIkpSTPsNZRBxBz+jMS4AGero2I6XkxLMaMva0d7Fx+x7ec+L8UpciSVJO+ms5ew64D7ggpVQNEBH/T1GqkvJkbX0zKTkYQJI0dPQ3WvOdwGbg3oi4NiLOBrxpR0PK2mxNzSWzDGeSpKGhz3CWUvrvlNL7gKXAvfSsqzkzIr4VEecUq0BpMNZsbaZiRLBg2rhSlyJJUk5yWSFgd0rpppTShcA84HHgswWvTMqD6rpmDps6lqrKXFYqkySp9Ab0Fyul1JRSuialdHahCpLyqbq+Z8FzSZKGCpsTNGx1dHWzvmG3a2pKkoYUw5mGrQ2Nu+nsTo7UlCQNKYYzDVuuqSlJGooMZxq29oazRTMMZ5KkocNwpmGruq6ZuZPHMG5UTquUSZJUFgxnGrbW1DlSU5I09BjONCx1dyfW1jez2C5NSdIQYzjTsLRx+x5aO7odDCBJGnIMZxqWql1TU5I0RBnONCxVb82m0bBbU5I0xBjONCw9v3UX08ZVMWVcValLkSRpQAxnGna6uhO/fb6eUw6fVupSJEkaMMOZhp0nappoaG7jnGWkB1t+AAASKUlEQVSzSl2KJEkDZjjTsHPnqq2MrAjOXDqz1KVIkjRghjMNKykl7ly1hVMXTWfi6JGlLkeSpAEznGlYeWFrMxsaWzjnSLs0JUlDk+FMw8pdq7YAGM4kSUOW4UzDyp2rt3D8oZOZOXF0qUuRJOmAGM40bGzcvodnNu7k3GWHlLoUSZIOmOFMw4ZdmpKk4cBwpmHjrlVbWTJzPIe7ZJMkaQgznGlYaNrdzh/Wb3PiWUnSkGc407Bwz3N1dHUn7zeTJA15hjMNC3eu2sLsSaM5eu6kUpciSdKgGM405O1p7+K+NfWcc+QsIqLU5UiSNCiGMw15v3uhntaObrs0JUnDguFMQ95dq7cwacxITlo4tdSlSJI0aIYzDWkdXd3c82wdZ79+JiMr/M9ZkjT0+ddMQ9ofXtzGjj0dnHOkXZqSpOHBcKYh7a5VWxg9cgRvPGJGqUuRJCkvDGcaslJK3LV6K3+yZAZjqipKXY4kSXlhONOQ9fTGHWze0eooTUnSsFKwcBYR342Iuoh4ptexqRFxd0SsyR6nZMcjIr4REdUR8VREnFCoujR83LlqCxUjgrOXzix1KZIk5U0hW86+D5z3qmOfA+5JKS0B7sn2Ac4HlmQ/VwLfKmBdGibuWrWVkxdMZcq4qlKXIklS3hQsnKWUfg9se9Xhi4Drs+3rgXf0On5D6vEwMDkiZheqNg196+qbWVPX7ELnkqRhp9j3nM1KKW3OtrcAe/+yzgVqep1Xmx2T9umu1VsBOMf7zSRJw0zJBgSklBKQBvq6iLgyIlZExIr6+voCVKah4M5VWzhq7kTmTh5T6lIkScqrYoezrXu7K7PHuuz4RmB+r/PmZcdeI6V0TUppeUpp+YwZzm11MKrb2crjL23nXCeelSQNQ8UOZ7cBl2XblwE/63X80mzU5inAjl7dn9Ir2KUpSRrOKgv1xhHxA+BNwPSIqAWuBr4E/CgirgA2ABdnp/8SeCtQDbQAHy5UXRraWju6uPXxjSyYNpYjZo0vdTmSJOVdwcJZSumSPp46ex/nJuCqQtWi4WFtfTMfu+lxnt28k398x1FERKlLkiQp7woWzqR8uvXxWv721mcYVTmC735oOWctdQoNSdLwZDhTWWtp7+Tqn63ilpW1nLxgKl+/5DhmT3KEpiRp+DKcqWy9sHUXV/3XY1TXN/PXZy3mE2cvobLC5WAlScOb4UxlJ6XEj1bUcPVtqxg/aiQ3Xv4GzlgyvdRlSZJUFIYzlZXmtk7+9tan+dkTmzh98TS+9t7jmDlhdKnLkiSpaAxnKhsbt+/hz697hPWNu/lf5xzBX71pMRUjHJEpSTq4GM5UFjZt38P7rnmI7S0d3PQXp3DK4dNKXZIkSSVhOFPJbd6xh0uufZjtuzu48SNv4Lj5k0tdkiRJJWM4U0lt3dnK+699hMbmdm644mSDmSTpoOe8BCqZup2tXHLNw9TtbOX6y0/ihEOnlLokSZJKzpYzlUTdrlYuufZhtuxs5frLT+bEw6aWuiRJksqCLWcquobmNj5w7SNs2t7K9z50EictMJhJkrSX4UxF1djcxvuvfZiapha++6GTeIOjMiVJegXDmYpm2+52PnDdI2xobOG7l53EqYsMZpIkvZrhTEWxo6WDD1z3CC827OY7l53EaYtdjkmSpH0xnKngOru6+dgPHqO6bhfXXrrcdTIlSeqHozVVcF/61XPct6aBr7zrGP70iBmlLkeSpLJmy5kK6icra7nu/hf50GkLuPik+aUuR5Kksmc4U8E8UbOdz9/6NKctmsbfvu31pS5HkqQhwXCmgti6s5Urb1jBrImj+Ob7T2Bkhf+pSZKUC/9iKu9aO7r4yxtX0tzWybWXLmfKuKpSlyRJ0pDhgADlVUqJv731GZ6o2c63P3giSw+ZWOqSJEkaUmw5U79aO7po7+zO+fzvPrCenzxWyyffvITzjjqkgJVJkjQ82XImuroTm7bvYV3Dbl6sb2Zdw27W1e/mxYbdbNy+h3FVFbzpdTM5Z9kszlw6k4mjR+7zfe5f08A//WI15y6bxcfPWlLk30KSpOHBcHaQerJmOzc8tIFnNu7gxcbdr2gdmzCqksNnjOOkBVO4ePp8tuxs5e7VW/nF05sZWRGctmg65yybxVuOnMXMCaMBWN+wm6tueowlMyfw1YuPY8SIKNWvJknSkBYppVLXcMCWL1+eVqxYUeoyhoyu7sSvn93Kdfet49H1TUwYVckbDp/KwunjOHzGeA6fPo6FM8YxY/woIuI1r32ipok7V23lzlVb2NDYQgSccOgUzl02i1tW1FLf3MZtV53BodPGlug3lCSpfEXEypTS8v2eZzgb/lraO7llRS3ffeBFNjS2MHfyGC4/YyEXL5/HhD66KPuTUuL5rbu4KwtqqzbtpGJEcOPlJ7tmpiRJfcg1nNmtOYxt2dHK9Q+t56ZHXmLHng6Omz+Zz5y7lHOXzaJyEPOORQRLD5nI0kMm8vGzl1CzrYVdrZ0cOceRmZIkDZbhbBjasqOVr9z5HLc9sYnulDh32SF85E8O58TDphTk8+ZPtRtTkqR8MZwNIyklbn60hv/vF8/S0d3NB085jMtPX+g9YJIkDSGGs2FiQ+NuPveTp3loXSOnHj6NL73raA6bNq7UZUmSpAEynA1xXd2J7z3wIv981/OMHDGC//+dR/O+k+a/ZrSlJEkaGgxnQ9jzW3bxmZ88xZM12zl76Uy++GdHMXvSmFKXJUmSBsFwNgS1d3bzf35bzTfvrWbC6JF8/X3H8fZj59haJknSMGA4G2JebNjNR29cyfNbd3HRcXP4+wuOZNr4UaUuS5Ik5YnhbAjZ0dLB5d9/lB17Orju0uW8+chZpS5JkiTlmeFsiOjs6uaqmx6jtqmFH/zFKSxfMLXUJUmSpAIwnA0RX/zFs9xf3cBX3n2MwUySpGHswNfwUdHc/IeX+P6D67nijIVcvHx+qcuRJEkFZDgrc4+u38b/+7Nn+NMjZvD585eWuhxJklRghrMyVtvUwkdvXMn8KWP5t0uOH9Ri5ZIkaWjwr32Z2t3WyUeuX0F7VzfXXracSWNGlrokSZJUBIazMtTdnfjUj57gha27+Pf3n8CiGeNLXZIkSSoSw1kZ+td71nDnqq38zVtfzxuPmFHqciRJUhE5lUaRvNTYwu72TmZMGMWUsVVUjNj3Uku/eGoz37hnDe85cR5XnLGwyFVKkqRSM5wVSEqJ57bs4o5ntnDHM1t4fuuul58bETBt/Cimjx/FjAmjmDF+FNMnVDFx9Ej+7TdrOPGwKXzxz45yrUxJkg5ChrM86u5OPFm7nTtWbeHOZ7awvrGFCDhpwVT+/oIjOWTSaOp3tdHQ3Eb9rraXt6u37qKhuZ32rm7mTx3Dtz94IqMqK0r960iSpBIwnA1QSomW9i52t3Wyq62T3W2dNO5u53fP13PHM1vYsrOVyhHBaYun85dvXMSbXz+LGRP2vzB5SomdezoZU1VBVaW3AkqSdLAynPXjobWNfPmO52jOQlhzaye72zvpTq89d1TlCN54xAw+e/TrOGvprAFPfRERTBrrdBmSJB3sDGf9qKocwcQxI5kzeTTjqioZP7qS8aN6fsaNqmTC6ErGVfU8Hj1vEmOrvJySJGlwTBP9OPGwKdxw+cmlLkOSJB1Eyurmpog4LyKej4jqiPhcqeuRJEkqtrIJZxFRAXwTOB84ErgkIo4sbVWSJEnFVTbhDDgZqE4prUsptQM3AxeVuCZJkqSiKqdwNheo6bVfmx2TJEk6aJRTOMtJRFwZESsiYkV9fX2py5EkScqrcgpnG4H5vfbnZcdeIaV0TUppeUpp+YwZLgouSZKGl3IKZ48CSyJiYURUAe8DbitxTZIkSUVVNvOcpZQ6I+JjwJ1ABfDdlNKqEpclSZJUVGUTzgBSSr8EflnqOiRJkkqlnLo1JUmSDnqGM0mSpDJiOJMkSSojhjNJkqQyEimlUtdwwCKiHthQ4I+ZDjQU+DP0Sl7z4vJ6F5fXu7i83sXnNe/bYSml/U7SOqTDWTFExIqU0vJS13Ew8ZoXl9e7uLzexeX1Lj6v+eDZrSlJklRGDGeSJEllxHC2f9eUuoCDkNe8uLzexeX1Li6vd/F5zQfJe84kSZLKiC1nkiRJZcRw1o+IOC8ino+I6oj4XKnrGW4i4rsRURcRz/Q6NjUi7o6INdnjlFLWOJxExPyIuDciVkfEqoj4RHbca14gETE6Iv4QEU9m1/x/Z8cXRsQj2XfLDyOiqtS1DicRURERj0fE7dm+17tAImJ9RDwdEU9ExIrsmN8pg2Q460NEVADfBM4HjgQuiYgjS1vVsPN94LxXHfsccE9KaQlwT7av/OgE/mdK6UjgFOCq7L9pr3nhtAFnpZSOBY4DzouIU4AvA19LKS0GmoArSljjcPQJ4Nle+17vwjozpXRcr+kz/E4ZJMNZ304GqlNK61JK7cDNwEUlrmlYSSn9Htj2qsMXAddn29cD7yhqUcNYSmlzSumxbHsXPX+85uI1L5jUoznbHZn9JOAs4MfZca95HkXEPOBtwHXZfuD1Lja/UwbJcNa3uUBNr/3a7JgKa1ZKaXO2vQWYVcpihquIWAAcDzyC17ygsi62J4A64G5gLbA9pdSZneJ3S379K/AZoDvbn4bXu5AScFdErIyIK7NjfqcMUmWpC5D6klJKEeFw4jyLiPHAT4BPppR29jQs9PCa519KqQs4LiImA7cCS0tc0rAVERcAdSmllRHxplLXc5A4I6W0MSJmAndHxHO9n/Q75cDYcta3jcD8XvvzsmMqrK0RMRsge6wrcT3DSkSMpCeY/VdK6afZYa95EaSUtgP3AqcCkyNi7z+O/W7Jn9OBt0fEenpuRTkL+Dpe74JJKW3MHuvo+cfHyfidMmiGs749CizJRvlUAe8DbitxTQeD24DLsu3LgJ+VsJZhJbv35jvAsymlr/Z6ymteIBExI2sxIyLGAG+h516/e4F3Z6d5zfMkpfT5lNK8lNICer6zf5NS+gBe74KIiHERMWHvNnAO8Ax+pwyak9D2IyLeSs/9CxXAd1NK/1TikoaViPgB8CZgOrAVuBr4b+BHwKHABuDilNKrBw3oAETEGcB9wNP88X6cv6HnvjOveQFExDH03BBdQc8/hn+UUvqHiDicnpadqcDjwAdTSm2lq3T4ybo1/1dK6QKvd2Fk1/XWbLcSuCml9E8RMQ2/UwbFcCZJklRG7NaUJEkqI4YzSZKkMmI4kyRJKiOGM0mSpDJiOJMkSSojhjNJQ0ZEzIqImyJiXbZczEMR8WclquVNEXFar/2PRsSlpahF0vDi8k2ShoRsEt3/Bq5PKb0/O3YY8PYCfmZlrzUZX+1NQDPwIEBK6duFqkPSwcV5ziQNCRFxNvD3KaU37uO5CuBL9ASmUcA3U0r/kU1E+gWgATgKWEnPBKQpIk4EvgqMz57/UEppc0T8FngCOAP4AfAC8HdAFdAIfAAYAzwMdAH1wF8DZwPNKaV/jojjgG8DY+lZ6PzylFJT9t6PAGcCk4ErUkr35e8qSRoO7NaUNFQsAx7r47krgB0ppZOAk4C/iIiF2XPHA58EjgQOB07P1hj9N+DdKaUTge8CvVcAqUopLU8p/QtwP3BKSul4emaZ/0xKaT094etrKaXj9hGwbgA+m1I6hp4VGa7u9VxlSunkrKarkaRXsVtT0pAUEd+kp3WrnZ4lYo6JiL3rJ04ClmTP/SGlVJu95glgAbCdnpa0u3t6S6kANvd6+x/22p4H/DBbwLkKeHE/dU0CJqeUfpcduh64pdcpexecX5nVIkmvYDiTNFSsAt61dyeldFVETAdWAC8Bf51SurP3C7Juzd5rKHbR870XwKqU0ql9fNbuXtv/Bnw1pXRbr27Swdhbz95aJOkV7NaUNFT8BhgdEX/V69jY7PFO4K+y7koi4oiIGNfPez0PzIiIU7PzR0bEsj7OnQRszLYv63V8FzDh1SenlHYATRHxJ9mhPwd+9+rzJKkv/qtN0pCQ3cT/DuBrEfEZem7E3w18lp5uwwXAY9moznrgHf28V3vWBfqNrBuyEvhXelrnXu0LwC0R0URPQNx7L9vPgR9HxEX0DAjo7TLg2xExFlgHfHjgv7Gkg5WjNSVJksqI3ZqSJEllxHAmSZJURgxnkiRJZcRwJkmSVEYMZ5IkSWXEcCZJklRGDGeSJEllxHAmSZJURv4vcjkJItaFK18AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 720x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(10,6))\n",
    "plt.ylabel(\"Avg Fitness of Parents\")\n",
    "plt.xlabel(\"Generation\")\n",
    "plt.plot(generation_fitness)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test trained agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/brandonbrown/anaconda3/envs/deeprl/lib/python3.6/site-packages/ipykernel/__main__.py:14: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    }
   ],
   "source": [
    "state = torch.from_numpy(env.reset()).float()\n",
    "done = False\n",
    "for i in range(200):\n",
    "    action = get_action_from_agent_weights(agents[0],state)\n",
    "    state, reward, done, info = env.step(action)\n",
    "    state = torch.from_numpy(state).float()\n",
    "    if done:\n",
    "        print(\"Game over at time step {}\".format(i,))\n",
    "        break\n",
    "    env.render()\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
